{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# How to use OpenNMT-py as a Library"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The example notebook (available [here](https://github.com/OpenNMT/OpenNMT-py/blob/master/docs/source/examples/Library.ipynb)) should be able to run as a standalone execution, provided `onmt` is in the path (installed via `pip` for instance).\n",
    "\n",
    "Some parts may not be 100% 'library-friendly' but it's mostly workable."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Import a few modules and functions that will be necessary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import yaml\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from argparse import Namespace\n",
    "from collections import defaultdict, Counter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import onmt\n",
    "from onmt.inputters.inputter import _load_vocab, _build_fields_vocab, get_fields, IterOnDevice\n",
    "from onmt.inputters.corpus import ParallelCorpus\n",
    "from onmt.inputters.dynamic_iterator import DynamicDatasetIter\n",
    "from onmt.translate import GNMTGlobalScorer, Translator, TranslationBuilder\n",
    "from onmt.utils.misc import set_random_seed"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Enable logging"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<RootLogger root (INFO)>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# enable logging\n",
    "from onmt.utils.logging import init_logger, logger\n",
    "init_logger()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set random seed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "is_cuda = torch.cuda.is_available()\n",
    "set_random_seed(1111, is_cuda)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Retrieve data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make a proper example, we will need some data, as well as some vocabulary(ies).\n",
    "\n",
    "Let's take the same data as in the [quickstart](https://opennmt.net/OpenNMT-py/quickstart.html):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2020-09-25 15:28:05--  https://s3.amazonaws.com/opennmt-trainingdata/toy-ende.tar.gz\n",
      "Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.217.18.38\n",
      "Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.217.18.38|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 1662081 (1,6M) [application/x-gzip]\n",
      "Saving to: ‘toy-ende.tar.gz.5’\n",
      "\n",
      "toy-ende.tar.gz.5   100%[===================>]   1,58M  2,33MB/s    in 0,7s    \n",
      "\n",
      "2020-09-25 15:28:07 (2,33 MB/s) - ‘toy-ende.tar.gz.5’ saved [1662081/1662081]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "!wget https://s3.amazonaws.com/opennmt-trainingdata/toy-ende.tar.gz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "!tar xf toy-ende.tar.gz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "config.yaml  src-test.txt   src-val.txt   tgt-train.txt\r\n",
      "\u001b[0m\u001b[01;34mrun\u001b[0m/         src-train.txt  tgt-test.txt  tgt-val.txt\r\n"
     ]
    }
   ],
   "source": [
    "ls toy-ende"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare data and vocab"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As for any use case of OpenNMT-py 2.0, we can start by creating a simple YAML configuration with our datasets. This is the easiest way to build the proper `opts` `Namespace` that will be used to create the vocabulary(ies)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "yaml_config = \"\"\"\n",
    "## Where the vocab(s) will be written\n",
    "save_data: toy-ende/run/example\n",
    "src_vocab: toy-ende/run/example.vocab.src\n",
    "tgt_vocab: toy-ende/run/example.vocab.tgt\n",
    "# Corpus opts:\n",
    "data:\n",
    "    corpus:\n",
    "        path_src: toy-ende/src-train.txt\n",
    "        path_tgt: toy-ende/tgt-train.txt\n",
    "        transforms: []\n",
    "        weight: 1\n",
    "    valid:\n",
    "        path_src: toy-ende/src-val.txt\n",
    "        path_tgt: toy-ende/tgt-val.txt\n",
    "        transforms: []\n",
    "\"\"\"\n",
    "config = yaml.safe_load(yaml_config)\n",
    "with open(\"toy-ende/config.yaml\", \"w\") as f:\n",
    "    f.write(yaml_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from onmt.utils.parse import ArgumentParser\n",
    "parser = ArgumentParser(description='build_vocab.py')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from onmt.opts import dynamic_prepare_opts\n",
    "dynamic_prepare_opts(parser, build_vocab_only=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "base_args = ([\"-config\", \"toy-ende/config.yaml\", \"-n_sample\", \"10000\"])\n",
    "opts, unknown = parser.parse_known_args(base_args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Namespace(config='toy-ende/config.yaml', data=\"{'corpus': {'path_src': 'toy-ende/src-train.txt', 'path_tgt': 'toy-ende/tgt-train.txt', 'transforms': [], 'weight': 1}, 'valid': {'path_src': 'toy-ende/src-val.txt', 'path_tgt': 'toy-ende/tgt-val.txt', 'transforms': []}}\", insert_ratio=0.0, mask_length='subword', mask_ratio=0.0, n_sample=10000, src_onmttok_kwargs=\"{'mode': 'none'}\", tgt_onmttok_kwargs=\"{'mode': 'none'}\", overwrite=False, permute_sent_ratio=0.0, poisson_lambda=0.0, random_ratio=0.0, replace_length=-1, rotate_ratio=0.5, save_config=None, save_data='toy-ende/run/example', seed=-1, share_vocab=False, skip_empty_level='warning', src_seq_length=200, src_subword_model=None, src_subword_type='none', src_vocab=None, subword_alpha=0, subword_nbest=1, switchout_temperature=1.0, tgt_seq_length=200, tgt_subword_model=None, tgt_subword_type='none', tgt_vocab=None, tokendrop_temperature=1.0, tokenmask_temperature=1.0, transforms=[])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "opts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2020-09-25 15:28:08,068 INFO] Parsed 2 corpora from -data.\n",
      "[2020-09-25 15:28:08,069 INFO] Counter vocab from 10000 samples.\n",
      "[2020-09-25 15:28:08,070 INFO] Save 10000 transformed example/corpus.\n",
      "[2020-09-25 15:28:08,070 INFO] corpus's transforms: TransformPipe()\n",
      "[2020-09-25 15:28:08,101 INFO] Loading ParallelCorpus(toy-ende/src-train.txt, toy-ende/tgt-train.txt, align=None)...\n",
      "[2020-09-25 15:28:08,320 INFO] Just finished the first loop\n",
      "[2020-09-25 15:28:08,320 INFO] Counters src:24995\n",
      "[2020-09-25 15:28:08,321 INFO] Counters tgt:35816\n"
     ]
    }
   ],
   "source": [
    "from onmt.bin.build_vocab import build_vocab_main\n",
    "build_vocab_main(opts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "example.vocab.src  example.vocab.tgt  \u001b[0m\u001b[01;34msample\u001b[0m/\r\n"
     ]
    }
   ],
   "source": [
    "ls toy-ende/run"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We just created our source and target vocabularies, respectively `toy-ende/run/example.vocab.src` and `toy-ende/run/example.vocab.tgt`."
   ]
  },
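  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get a feel for what was written, we can peek at the first few lines of the source vocabulary file. This is only a sketch: it assumes the file is plain text with one entry per line, and simply prints whatever it finds."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import islice\n",
    "\n",
    "# Print the first five entries of the source vocabulary file.\n",
    "with open(\"toy-ende/run/example.vocab.src\", encoding=\"utf-8\") as f:\n",
    "    for line in islice(f, 5):\n",
    "        print(line.rstrip(\"\\n\"))"
   ]
  },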
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build fields"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can build the fields from the text files that were just created."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "src_vocab_path = \"toy-ende/run/example.vocab.src\"\n",
    "tgt_vocab_path = \"toy-ende/run/example.vocab.tgt\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2020-09-25 15:28:08,495 INFO] Loading src vocabulary from toy-ende/run/example.vocab.src\n",
      "[2020-09-25 15:28:08,554 INFO] Loaded src vocab has 24995 tokens.\n",
      "[2020-09-25 15:28:08,562 INFO] Loading tgt vocabulary from toy-ende/run/example.vocab.tgt\n",
      "[2020-09-25 15:28:08,617 INFO] Loaded tgt vocab has 35816 tokens.\n"
     ]
    }
   ],
   "source": [
    "# initialize the frequency counter\n",
    "counters = defaultdict(Counter)\n",
    "# load source vocab\n",
    "_src_vocab, _src_vocab_size = _load_vocab(\n",
    "    src_vocab_path,\n",
    "    'src',\n",
    "    counters)\n",
    "# load target vocab\n",
    "_tgt_vocab, _tgt_vocab_size = _load_vocab(\n",
    "    tgt_vocab_path,\n",
    "    'tgt',\n",
    "    counters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# initialize fields\n",
    "src_nfeats, tgt_nfeats = 0, 0 # do not support word features for now\n",
    "fields = get_fields(\n",
    "    'text', src_nfeats, tgt_nfeats)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'src': <onmt.inputters.text_dataset.TextMultiField at 0x7fca93802c50>,\n",
       " 'tgt': <onmt.inputters.text_dataset.TextMultiField at 0x7fca93802f60>,\n",
       " 'indices': <torchtext.data.field.Field at 0x7fca93802940>}"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fields"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2020-09-25 15:28:08,699 INFO]  * tgt vocab size: 30004.\n",
      "[2020-09-25 15:28:08,749 INFO]  * src vocab size: 24997.\n"
     ]
    }
   ],
   "source": [
    "# build fields vocab\n",
    "share_vocab = False\n",
    "vocab_size_multiple = 1\n",
    "src_vocab_size = 30000\n",
    "tgt_vocab_size = 30000\n",
    "src_words_min_frequency = 1\n",
    "tgt_words_min_frequency = 1\n",
    "vocab_fields = _build_fields_vocab(\n",
    "    fields, counters, 'text', share_vocab,\n",
    "    vocab_size_multiple,\n",
    "    src_vocab_size, src_words_min_frequency,\n",
    "    tgt_vocab_size, tgt_words_min_frequency)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An alternative way of creating these fields is to run `onmt_train` without actually training, to just output the necessary files."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare for training: model and optimizer creation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's get a few fields/vocab related variables to simplify the model creation a bit:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "src_text_field = vocab_fields[\"src\"].base_field\n",
    "src_vocab = src_text_field.vocab\n",
    "src_padding = src_vocab.stoi[src_text_field.pad_token]\n",
    "\n",
    "tgt_text_field = vocab_fields['tgt'].base_field\n",
    "tgt_vocab = tgt_text_field.vocab\n",
    "tgt_padding = tgt_vocab.stoi[tgt_text_field.pad_token]"
   ]
  },
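  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sizes logged by `_build_fields_vocab` include the special tokens that the fields add on top of the words loaded from the vocabulary files. As a quick check (a sketch, assuming the torchtext `Vocab` objects expose `itos`, as in the torchtext versions this notebook was written against), we can look at the size and the first entries of each vocabulary:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Vocabulary sizes, special tokens included.\n",
    "print(len(src_vocab), len(tgt_vocab))\n",
    "# The first entries are expected to be the special tokens,\n",
    "# followed by the most frequent words.\n",
    "print(src_vocab.itos[:5])\n",
    "print(tgt_vocab.itos[:5])"
   ]
  },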
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we specify the core model itself. Here we will build a small model with an encoder and an attention based input feeding decoder. Both models will be RNNs and the encoder will be bidirectional"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "emb_size = 100\n",
    "rnn_size = 500\n",
    "# Specify the core model.\n",
    "\n",
    "encoder_embeddings = onmt.modules.Embeddings(emb_size, len(src_vocab),\n",
    "                                             word_padding_idx=src_padding)\n",
    "\n",
    "encoder = onmt.encoders.RNNEncoder(hidden_size=rnn_size, num_layers=1,\n",
    "                                   rnn_type=\"LSTM\", bidirectional=True,\n",
    "                                   embeddings=encoder_embeddings)\n",
    "\n",
    "decoder_embeddings = onmt.modules.Embeddings(emb_size, len(tgt_vocab),\n",
    "                                             word_padding_idx=tgt_padding)\n",
    "decoder = onmt.decoders.decoder.InputFeedRNNDecoder(\n",
    "    hidden_size=rnn_size, num_layers=1, bidirectional_encoder=True, \n",
    "    rnn_type=\"LSTM\", embeddings=decoder_embeddings)\n",
    "\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "model = onmt.models.model.NMTModel(encoder, decoder)\n",
    "model.to(device)\n",
    "\n",
    "# Specify the tgt word generator and loss computation module\n",
    "model.generator = nn.Sequential(\n",
    "    nn.Linear(rnn_size, len(tgt_vocab)),\n",
    "    nn.LogSoftmax(dim=-1)).to(device)\n",
    "\n",
    "loss = onmt.utils.loss.NMTLossCompute(\n",
    "    criterion=nn.NLLLoss(ignore_index=tgt_padding, reduction=\"sum\"),\n",
    "    generator=model.generator)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we set up the optimizer. This could be a core torch optim class, or our wrapper which handles learning rate updates and gradient normalization automatically."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr = 1\n",
    "torch_optimizer = torch.optim.SGD(model.parameters(), lr=lr)\n",
    "optim = onmt.utils.optimizers.Optimizer(\n",
    "    torch_optimizer, learning_rate=lr, max_grad_norm=2)"
   ]
  },
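  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To illustrate what a maximum gradient norm means, here is a minimal, standalone sketch of gradient-norm clipping with plain PyTorch. It uses a toy `nn.Linear` layer (hypothetical, unrelated to the model above) and `torch.nn.utils.clip_grad_norm_`. It illustrates the technique under the assumption that `max_grad_norm` refers to this kind of clipping; it is not the wrapper's actual code path."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy layer and loss, only to produce some gradients (hypothetical, unrelated to the model above).\n",
    "# fork_rng keeps this demo from advancing the global random state set earlier.\n",
    "with torch.random.fork_rng():\n",
    "    toy = nn.Linear(4, 2)\n",
    "    toy_loss = toy(torch.randn(8, 4)).pow(2).sum()\n",
    "toy_loss.backward()\n",
    "\n",
    "# Rescale the gradients in place so that their global L2 norm is at most 2;\n",
    "# the returned value is the norm measured before clipping.\n",
    "norm_before = torch.nn.utils.clip_grad_norm_(toy.parameters(), max_norm=2)\n",
    "norm_after = torch.norm(torch.stack([p.grad.norm() for p in toy.parameters()]))\n",
    "print(float(norm_before), float(norm_after))"
   ]
  },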
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create the training and validation data iterators"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we need to create the dynamic dataset iterator.\n",
    "\n",
    "This is not very 'library-friendly' for now because of the way the `DynamicDatasetIter` constructor is defined. It may evolve in the future."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "src_train = \"toy-ende/src-train.txt\"\n",
    "tgt_train = \"toy-ende/tgt-train.txt\"\n",
    "src_val = \"toy-ende/src-val.txt\"\n",
    "tgt_val = \"toy-ende/tgt-val.txt\"\n",
    "\n",
    "# build the ParallelCorpus\n",
    "corpus = ParallelCorpus(\"corpus\", src_train, tgt_train)\n",
    "valid = ParallelCorpus(\"valid\", src_val, tgt_val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# build the training iterator\n",
    "train_iter = DynamicDatasetIter(\n",
    "    corpora={\"corpus\": corpus},\n",
    "    corpora_info={\"corpus\": {\"weight\": 1}},\n",
    "    transforms={},\n",
    "    fields=vocab_fields,\n",
    "    is_train=True,\n",
    "    batch_type=\"tokens\",\n",
    "    batch_size=4096,\n",
    "    batch_size_multiple=1,\n",
    "    data_type=\"text\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# make sure the iteration happens on GPU 0 (-1 for CPU, N for GPU N)\n",
    "train_iter = iter(IterOnDevice(train_iter, 0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# build the validation iterator\n",
    "valid_iter = DynamicDatasetIter(\n",
    "    corpora={\"valid\": valid},\n",
    "    corpora_info={\"valid\": {\"weight\": 1}},\n",
    "    transforms={},\n",
    "    fields=vocab_fields,\n",
    "    is_train=False,\n",
    "    batch_type=\"sents\",\n",
    "    batch_size=8,\n",
    "    batch_size_multiple=1,\n",
    "    data_type=\"text\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "valid_iter = IterOnDevice(valid_iter, 0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally we train."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2020-09-25 15:28:15,184 INFO] Start training loop and validate every 500 steps...\n",
      "[2020-09-25 15:28:15,185 INFO] corpus's transforms: TransformPipe()\n",
      "[2020-09-25 15:28:15,187 INFO] Loading ParallelCorpus(toy-ende/src-train.txt, toy-ende/tgt-train.txt, align=None)...\n",
      "[2020-09-25 15:28:21,140 INFO] Step 50/ 1000; acc:   7.52; ppl: 8832.29; xent: 9.09; lr: 1.00000; 18916/18871 tok/s;      6 sec\n",
      "[2020-09-25 15:28:24,869 INFO] Loading ParallelCorpus(toy-ende/src-train.txt, toy-ende/tgt-train.txt, align=None)...\n",
      "[2020-09-25 15:28:27,121 INFO] Step 100/ 1000; acc:   9.34; ppl: 1840.06; xent: 7.52; lr: 1.00000; 18911/18785 tok/s;     12 sec\n",
      "[2020-09-25 15:28:33,048 INFO] Step 150/ 1000; acc:  10.35; ppl: 1419.18; xent: 7.26; lr: 1.00000; 19062/19017 tok/s;     18 sec\n",
      "[2020-09-25 15:28:37,019 INFO] Loading ParallelCorpus(toy-ende/src-train.txt, toy-ende/tgt-train.txt, align=None)...\n",
      "[2020-09-25 15:28:39,022 INFO] Step 200/ 1000; acc:  11.14; ppl: 1127.44; xent: 7.03; lr: 1.00000; 19084/18911 tok/s;     24 sec\n",
      "[2020-09-25 15:28:45,073 INFO] Step 250/ 1000; acc:  12.46; ppl: 912.13; xent: 6.82; lr: 1.00000; 18575/18570 tok/s;     30 sec\n",
      "[2020-09-25 15:28:49,301 INFO] Loading ParallelCorpus(toy-ende/src-train.txt, toy-ende/tgt-train.txt, align=None)...\n",
      "[2020-09-25 15:28:51,151 INFO] Step 300/ 1000; acc:  13.04; ppl: 779.50; xent: 6.66; lr: 1.00000; 18394/18307 tok/s;     36 sec\n",
      "[2020-09-25 15:28:57,316 INFO] Step 350/ 1000; acc:  14.04; ppl: 685.48; xent: 6.53; lr: 1.00000; 18339/18173 tok/s;     42 sec\n",
      "[2020-09-25 15:29:02,117 INFO] Loading ParallelCorpus(toy-ende/src-train.txt, toy-ende/tgt-train.txt, align=None)...\n",
      "[2020-09-25 15:29:03,576 INFO] Step 400/ 1000; acc:  14.99; ppl: 590.20; xent: 6.38; lr: 1.00000; 18090/18029 tok/s;     48 sec\n",
      "[2020-09-25 15:29:09,546 INFO] Step 450/ 1000; acc:  16.00; ppl: 524.51; xent: 6.26; lr: 1.00000; 18726/18536 tok/s;     54 sec\n",
      "[2020-09-25 15:29:14,585 INFO] Loading ParallelCorpus(toy-ende/src-train.txt, toy-ende/tgt-train.txt, align=None)...\n",
      "[2020-09-25 15:29:15,596 INFO] Step 500/ 1000; acc:  16.78; ppl: 453.38; xent: 6.12; lr: 1.00000; 17877/17980 tok/s;     60 sec\n",
      "[2020-09-25 15:29:15,597 INFO] valid's transforms: TransformPipe()\n",
      "[2020-09-25 15:29:15,599 INFO] Loading ParallelCorpus(toy-ende/src-val.txt, toy-ende/tgt-val.txt, align=None)...\n",
      "[2020-09-25 15:29:24,528 INFO] Validation perplexity: 295.1\n",
      "[2020-09-25 15:29:24,529 INFO] Validation accuracy: 17.6533\n",
      "[2020-09-25 15:29:30,592 INFO] Step 550/ 1000; acc:  17.47; ppl: 421.26; xent: 6.04; lr: 1.00000; 7726/7610 tok/s;     75 sec\n",
      "[2020-09-25 15:29:36,055 INFO] Loading ParallelCorpus(toy-ende/src-train.txt, toy-ende/tgt-train.txt, align=None)...\n",
      "[2020-09-25 15:29:36,695 INFO] Step 600/ 1000; acc:  18.95; ppl: 354.44; xent: 5.87; lr: 1.00000; 17470/17598 tok/s;     82 sec\n",
      "[2020-09-25 15:29:42,794 INFO] Step 650/ 1000; acc:  19.60; ppl: 328.47; xent: 5.79; lr: 1.00000; 18994/18793 tok/s;     88 sec\n",
      "[2020-09-25 15:29:48,635 INFO] Loading ParallelCorpus(toy-ende/src-train.txt, toy-ende/tgt-train.txt, align=None)...\n",
      "[2020-09-25 15:29:48,924 INFO] Step 700/ 1000; acc:  20.57; ppl: 285.48; xent: 5.65; lr: 1.00000; 17856/17788 tok/s;     94 sec\n",
      "[2020-09-25 15:29:54,898 INFO] Step 750/ 1000; acc:  21.97; ppl: 249.06; xent: 5.52; lr: 1.00000; 19030/18924 tok/s;    100 sec\n",
      "[2020-09-25 15:30:01,233 INFO] Step 800/ 1000; acc:  22.66; ppl: 228.54; xent: 5.43; lr: 1.00000; 17571/17471 tok/s;    106 sec\n",
      "[2020-09-25 15:30:01,357 INFO] Loading ParallelCorpus(toy-ende/src-train.txt, toy-ende/tgt-train.txt, align=None)...\n",
      "[2020-09-25 15:30:07,345 INFO] Step 850/ 1000; acc:  24.32; ppl: 193.65; xent: 5.27; lr: 1.00000; 18344/18313 tok/s;    112 sec\n",
      "[2020-09-25 15:30:11,363 INFO] Loading ParallelCorpus(toy-ende/src-train.txt, toy-ende/tgt-train.txt, align=None)...\n",
      "[2020-09-25 15:30:13,487 INFO] Step 900/ 1000; acc:  24.93; ppl: 177.65; xent: 5.18; lr: 1.00000; 18293/18105 tok/s;    118 sec\n",
      "[2020-09-25 15:30:19,670 INFO] Step 950/ 1000; acc:  26.33; ppl: 157.10; xent: 5.06; lr: 1.00000; 17791/17746 tok/s;    124 sec\n",
      "[2020-09-25 15:30:24,072 INFO] Loading ParallelCorpus(toy-ende/src-train.txt, toy-ende/tgt-train.txt, align=None)...\n",
      "[2020-09-25 15:30:25,820 INFO] Step 1000/ 1000; acc:  27.47; ppl: 137.64; xent: 4.92; lr: 1.00000; 17942/17962 tok/s;    131 sec\n",
      "[2020-09-25 15:30:25,822 INFO] Loading ParallelCorpus(toy-ende/src-val.txt, toy-ende/tgt-val.txt, align=None)...\n",
      "[2020-09-25 15:30:34,665 INFO] Validation perplexity: 241.801\n",
      "[2020-09-25 15:30:34,666 INFO] Validation accuracy: 20.2837\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<onmt.utils.statistics.Statistics at 0x7fca934e8e80>"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "report_manager = onmt.utils.ReportMgr(\n",
    "    report_every=50, start_time=None, tensorboard_writer=None)\n",
    "\n",
    "trainer = onmt.Trainer(model=model,\n",
    "                       train_loss=loss,\n",
    "                       valid_loss=loss,\n",
    "                       optim=optim,\n",
    "                       report_manager=report_manager,\n",
    "                       dropout=[0.1])\n",
    "\n",
    "trainer.train(train_iter=train_iter,\n",
    "              train_steps=1000,\n",
    "              valid_iter=valid_iter,\n",
    "              valid_steps=500)"
   ]
  },
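  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the log above, `xent` is the average cross-entropy per target token and `ppl` is its exponential. A quick sanity check with the values from the last training step (a small discrepancy is expected, since the logged `xent` is rounded to two decimals):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "# Values copied from the \"Step 1000\" line of the training log above.\n",
    "xent = 4.92\n",
    "logged_ppl = 137.64\n",
    "print(math.exp(xent), logged_ppl)"
   ]
  },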
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Translate"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For translation, we can build a \"traditional\" (as opposed to dynamic) dataset for now."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "src_data = {\"reader\": onmt.inputters.str2reader[\"text\"](), \"data\": src_val}\n",
    "tgt_data = {\"reader\": onmt.inputters.str2reader[\"text\"](), \"data\": tgt_val}\n",
    "_readers, _data = onmt.inputters.Dataset.config(\n",
    "    [('src', src_data), ('tgt', tgt_data)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = onmt.inputters.Dataset(\n",
    "    vocab_fields, readers=_readers, data=_data,\n",
    "    sort_key=onmt.inputters.str2sortkey[\"text\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_iter = onmt.inputters.OrderedIterator(\n",
    "            dataset=dataset,\n",
    "            device=\"cuda\",\n",
    "            batch_size=10,\n",
    "            train=False,\n",
    "            sort=False,\n",
    "            sort_within_batch=True,\n",
    "            shuffle=False\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "src_reader = onmt.inputters.str2reader[\"text\"]\n",
    "tgt_reader = onmt.inputters.str2reader[\"text\"]\n",
    "scorer = GNMTGlobalScorer(alpha=0.7, \n",
    "                          beta=0., \n",
    "                          length_penalty=\"avg\", \n",
    "                          coverage_penalty=\"none\")\n",
    "gpu = 0 if torch.cuda.is_available() else -1\n",
    "translator = Translator(model=model, \n",
    "                        fields=vocab_fields, \n",
    "                        src_reader=src_reader, \n",
    "                        tgt_reader=tgt_reader, \n",
    "                        global_scorer=scorer,\n",
    "                        gpu=gpu)\n",
    "builder = onmt.translate.TranslationBuilder(data=dataset, \n",
    "                                            fields=vocab_fields)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Note**: translations will be very poor, because of the very low quantity of data, the absence of proper tokenization, and the brevity of the training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "SENT 0: ['Parliament', 'Does', 'Not', 'Support', 'Amendment', 'Freeing', 'Tymoshenko']\n",
      "PRED 0: Parlament das Parlament über die Europäische Parlament , die sich in der Lage in der Lage ist , die es in der Lage sind .\n",
      "PRED SCORE: -1.5935\n",
      "\n",
      "\n",
      "SENT 0: ['Today', ',', 'the', 'Ukraine', 'parliament', 'dismissed', ',', 'within', 'the', 'Code', 'of', 'Criminal', 'Procedure', 'amendment', ',', 'the', 'motion', 'to', 'revoke', 'an', 'article', 'based', 'on', 'which', 'the', 'opposition', 'leader', ',', 'Yulia', 'Tymoshenko', ',', 'was', 'sentenced', '.']\n",
      "PRED 0: In der Nähe des Hotels , die in der Lage , die sich in der Lage ist , in der Lage , die in der Lage , die in der Lage ist .\n",
      "PRED SCORE: -1.7173\n",
      "\n",
      "\n",
      "SENT 0: ['The', 'amendment', 'that', 'would', 'lead', 'to', 'freeing', 'the', 'imprisoned', 'former', 'Prime', 'Minister', 'was', 'revoked', 'during', 'second', 'reading', 'of', 'the', 'proposal', 'for', 'mitigation', 'of', 'sentences', 'for', 'economic', 'offences', '.']\n",
      "PRED 0: Die Tatsache , die sich in der Lage in der Lage ist , die für eine Antwort der Entwicklung für die Entwicklung von Präsident .\n",
      "PRED SCORE: -1.6834\n",
      "\n",
      "\n",
      "SENT 0: ['In', 'October', ',', 'Tymoshenko', 'was', 'sentenced', 'to', 'seven', 'years', 'in', 'prison', 'for', 'entering', 'into', 'what', 'was', 'reported', 'to', 'be', 'a', 'disadvantageous', 'gas', 'deal', 'with', 'Russia', '.']\n",
      "PRED 0: In der Nähe wurde die Menschen in der Lage ist , die sich in der Lage <unk> .\n",
      "PRED SCORE: -1.5765\n",
      "\n",
      "\n",
      "SENT 0: ['The', 'verdict', 'is', 'not', 'yet', 'final;', 'the', 'court', 'will', 'hear', 'Tymoshenko', '&apos;s', 'appeal', 'in', 'December', '.']\n",
      "PRED 0: Es ist nicht der Fall , die in der Lage in der Lage sind .\n",
      "PRED SCORE: -1.3287\n",
      "\n",
      "\n",
      "SENT 0: ['Tymoshenko', 'claims', 'the', 'verdict', 'is', 'a', 'political', 'revenge', 'of', 'the', 'regime;', 'in', 'the', 'West', ',', 'the', 'trial', 'has', 'also', 'evoked', 'suspicion', 'of', 'being', 'biased', '.']\n",
      "PRED 0: Um in der Lage ist auch eine Lösung Rolle .\n",
      "PRED SCORE: -1.3975\n",
      "\n",
      "\n",
      "SENT 0: ['The', 'proposal', 'to', 'remove', 'Article', '365', 'from', 'the', 'Code', 'of', 'Criminal', 'Procedure', ',', 'upon', 'which', 'the', 'former', 'Prime', 'Minister', 'was', 'sentenced', ',', 'was', 'supported', 'by', '147', 'members', 'of', 'parliament', '.']\n",
      "PRED 0: Der Vorschlag , die in der Lage , die in der Lage , die in der Lage ist , war er von der Fall <unk> wurde .\n",
      "PRED SCORE: -1.6062\n",
      "\n",
      "\n",
      "SENT 0: ['Its', 'ratification', 'would', 'require', '226', 'votes', '.']\n",
      "PRED 0: Es wäre noch einmal noch einmal <unk> .\n",
      "PRED SCORE: -1.8001\n",
      "\n",
      "\n",
      "SENT 0: ['Libya', '&apos;s', 'Victory']\n",
      "PRED 0: In der Nähe des Hotels befindet sich in der Nähe des Hotels in der Lage .\n",
      "PRED SCORE: -1.7097\n",
      "\n",
      "\n",
      "SENT 0: ['The', 'story', 'of', 'Libya', '&apos;s', 'liberation', ',', 'or', 'rebellion', ',', 'already', 'has', 'its', 'defeated', '.']\n",
      "PRED 0: In der Nähe des Hotels in der Lage ist in der Lage .\n",
      "PRED SCORE: -1.7885\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for batch in data_iter:\n",
    "    trans_batch = translator.translate_batch(\n",
    "        batch=batch, src_vocabs=[src_vocab],\n",
    "        attn_debug=False)\n",
    "    translations = builder.from_batch(trans_batch)\n",
    "    for trans in translations:\n",
    "        print(trans.log(0))\n",
    "    break"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
